import hydra
import random
import os
from tqdm import tqdm
import numpy as np

from dataset.couple import CoupleData
from utils.utils import get_logger
from utils.writer import Writer
from omegaconf import OmegaConf
from hydra.core.hydra_config import HydraConfig
from utils.utils import set_random_seed
from omegaconf import open_dict


from model.PMMR.model import PMMRModel



def save_array_to_npy(array, filename_prefix, cfg):
    """Save ``array`` as a NumPy ``.npy`` file under ``cfg.log.ATE``.

    The file name is built as
    ``{cfg.mdl_name}_{cfg.dataset.type}_{cfg.dataset.num}_{filename_prefix}``;
    ``np.save`` appends the ``.npy`` extension when it is missing.

    Args:
        array: Data to store (anything ``np.save`` accepts).
        filename_prefix: Trailing part of the file name, e.g. ``"ATE_h"``.
        cfg: Config object exposing ``mdl_name``, ``dataset.type``,
            ``dataset.num`` and ``log.ATE`` (the output directory).
    """
    output_dir = cfg.log.ATE
    # np.save does not create missing directories, so make sure the target
    # exists. An empty path means "current directory" for os.path.join and
    # must not be handed to os.makedirs.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    save_filename = f"{cfg.mdl_name}_{cfg.dataset.type}_{cfg.dataset.num}_{filename_prefix}"
    np.save(os.path.join(output_dir, save_filename), array)


def causal_mae(pre, tar) -> np.ndarray:
    """Return the per-row mean squared difference between ``pre`` and ``tar``.

    NOTE: despite the name, this computes a mean *squared* error (not a mean
    absolute error) and returns one value per row rather than a single float.
    The name is kept so existing callers continue to work.

    Args:
        pre: Predicted values, array-like with at least two dimensions.
        tar: Target values, broadcastable against ``pre``.

    Returns:
        Array holding ``mean((pre - tar) ** 2)`` taken along axis 1.
    """
    # Coerce so plain (nested) lists are accepted as well as ndarrays.
    difference = np.asarray(pre) - np.asarray(tar)
    return np.mean(difference ** 2, axis=1)


# Reference: experiment setups recorded by the original author.
# Going by those notes, dataset columns 0-4 hold A_1..A_5 and columns 5-8 hold
# Y_1..Y_4. Only "A_2 -> Y_2" is wired in below; the others are kept here so
# they can be reproduced. Where the hyperparameters (lamh*, lamq*, scale, h)
# are configured is not visible in this file.
#
#   effect            W  Z  A      Y  methods                       hyperparameters
#   A_4 -> Y_3        2  6  3      7  predict_h/predict_q/drtest    lamh1=0.2 lamh2=0.2 lamq1=1 lamq2=1 scale=0.25
#                                                                   (alt: lamq1=0.2 lamq2=0.2)
#   (A_3,A_4) -> Y_1  4  7  [2,3]  5  *_mul variants (a)            lamh1=0.2 lamh2=0.2 lamq1=0.05 lamq2=0.05 scale=0.5
#   (A_1,A_3) -> Y_1  4  7  [0,2]  5  *_mul variants (b)            lamh1=0.2 lamh2=0.2 lamq1=1 lamq2=1 scale=0.6
#   A_2 -> Y_2  [on]  4  7  1      6  predict_h/predict_q/drtest    lamh1=0.2 lamh2=0.2 lamq1=1 lamq2=1 scale=0.25
#                                                                   (alt: lamq1=0.2 lamq2=0.2)
#   (A_1,A_5) -> Y_4  2  6  [0,4]  8  *_mul variants (b)            lamh1=0.2 lamh2=0.2 lamq1=0.2 lamq2=0.2 scale=0.5
#   A_3 -> Y_1        4  7  2      5  predict_h/predict_q/drtest    lamh1=0.05 lamh2=0.05 lamq1=0.2 lamq2=0.2 scale=0.5 h=1.5
#
# "*_mul variants" = predict_h_mul / predict_q_mul / drtest_mul, with A taken
# as dataset[:, [i, j]] (no extra axis added).
#   (a) predict_h_mul was called with `treatment` as-is.
#   (b) predict_h_mul was called with `treatment[:, np.newaxis]`.


def _column(data, index):
    """Return ``data[:, index]`` with a trailing axis added (a column vector
    when ``data`` is a 2-D array)."""
    return data[:, index][:, np.newaxis]


def _estimate_ate_for_seed(seed, num_samples, treatment,
                           w_col=4, z_col=7, a_col=1, y_col=6):
    """Fit a fresh ``PMMRModel`` for one seed and return its three estimates.

    Defaults select the "A_2 -> Y_2, Z: Y_3, W: A_5" setup.

    Args:
        seed: Seed passed to ``CoupleData`` for the training data; the test
            data is generated with ``seed + 1``.
        num_samples: Training-set size passed to ``CoupleData``.
        treatment: Treatment grid forwarded to the ``predict_*`` / ``drtest``
            calls (presumably a 1-D array; it is indexed with
            ``[:, np.newaxis]`` for ``predict_h``).
        w_col, z_col, a_col, y_col: Column indices of W, Z, A and Y.

    Returns:
        Tuple ``(ATE_h, ATE_q, ATE_dr)`` as returned by ``predict_h``,
        ``predict_q`` and ``drtest``.
    """
    # Keep this creation order: datasets first, then the model.
    train_dataset = CoupleData(seed, num_samples).generatate_coup()
    test_dataset = CoupleData.generate_test(1000, seed + 1)
    pmmr_train = PMMRModel()

    W = _column(train_dataset, w_col)
    Z = _column(train_dataset, z_col)
    A = _column(train_dataset, a_col)
    Y = _column(train_dataset, y_col)

    # Outcome bridge h: fit, then evaluate on the test W.
    pmmr_train.fit_h(A, W, Z, Y, X=None)
    W_test = _column(test_dataset, w_col)
    ATE_h = pmmr_train.predict_h(treatment[:, np.newaxis], W_test)

    # Treatment bridge q: fit, then evaluate on the test (A, Z, Y).
    pmmr_train.fit_q(A, W, Z, X=None)
    Z_test = _column(test_dataset, z_col)
    A_test = _column(test_dataset, a_col)
    Y_test = _column(test_dataset, y_col)
    ATE_q = pmmr_train.predict_q(treatment, A_test, Z_test, Y_test, X=None)

    # Doubly robust estimate; relies on both fits above.
    ATE_dr = pmmr_train.drtest(treatment, A_test, Z_test, W_test, Y_test, X=None)

    return ATE_h, ATE_q, ATE_dr


def Experiment(cfg):
    """Run the PMMR experiment over a fixed list of seeds and save the results.

    For every seed a model is fitted on freshly generated ``CoupleData`` and
    the h-based, q-based and doubly robust estimates are collected. The ground
    truth and the three stacked result arrays are written with
    ``save_array_to_npy``.

    Args:
        cfg: Hydra/OmegaConf config. Reads ``mdl_name``, ``dataset.type``,
            ``dataset.num``, ``data.train_dir``, ``data.test_dir`` and
            ``log.chkpt_dir`` (plus whatever ``get_logger``, ``Writer`` and
            ``save_array_to_npy`` read).

    Raises:
        ValueError: If ``cfg.data.train_dir`` or ``cfg.data.test_dir`` is empty.
        NotImplementedError: If ``cfg.mdl_name`` is not ``'rkhs'``.
    """
    logger = get_logger(cfg, os.path.basename(__file__))
    # Not used below; kept because constructing it may set up the tensorboard
    # log directory. Confirm before removing.
    writer = Writer(cfg, "tensorboard")
    os.makedirs(cfg.log.chkpt_dir, exist_ok=True)
    cfg_str = OmegaConf.to_yaml(cfg)
    logger.info("Config:\n" + cfg_str)
    # NOTE(review): the data below is generated synthetically, so these
    # directories do not appear to be read in this file. Check kept as-is.
    if cfg.data.train_dir == "" or cfg.data.test_dir == "":
        logger.error("train or test data directory cannot be empty.")
        raise ValueError("Please specify directories of data")
    logger.info("Set up train process")
    logger.info("Making train dataset...")

    seeds = np.array([1009, 1109, 1656, 1816, 2029,
                      2297, 2533, 2847, 4759, 4379,
                      4388, 4987, 5518, 5654, 5979,
                      7422, 7987, 8455, 9783, 9886])

    # Only the RKHS model is implemented. Previously any other name fell
    # through to the save calls and crashed with a NameError.
    if cfg.mdl_name != 'rkhs':
        logger.error(f"Unsupported model name: {cfg.mdl_name!r}")
        raise NotImplementedError(
            f"Experiment only supports mdl_name='rkhs', got {cfg.mdl_name!r}"
        )

    ATE_h_list, ATE_q_list, ATE_dr_list = [], [], []

    treatment, tar = CoupleData.generate_effect_exm1(0, 1, 10)

    if cfg.dataset.type == 'couple':
        for seed in tqdm(seeds):
            ATE_h, ATE_q, ATE_dr = _estimate_ate_for_seed(
                seed, cfg.dataset.num, treatment
            )
            ATE_h_list.append(ATE_h)
            ATE_q_list.append(ATE_q)
            ATE_dr_list.append(ATE_dr)
    else:
        # Behaviour kept (empty result arrays are still written), but it is
        # no longer silent.
        logger.warning(
            f"Unsupported dataset type {cfg.dataset.type!r}; "
            "no estimates were computed and empty result arrays will be saved."
        )

    # The "Groud_Truth" spelling is kept on purpose: it is part of the output
    # file name and downstream consumers may depend on it.
    save_array_to_npy(tar, "Groud_Truth", cfg)
    save_array_to_npy(np.array(ATE_h_list), "ATE_h", cfg)
    save_array_to_npy(np.array(ATE_q_list), "ATE_q", cfg)
    save_array_to_npy(np.array(ATE_dr_list), "ATE_dr", cfg)





@hydra.main(version_base="1.2", config_path="config", config_name="default")
def main(hydra_cfg):
    """Hydra entry point: finalise the config, seed the RNGs, run the experiment.

    The Hydra job-logging settings are copied onto the config as
    ``job_logging_cfg``; a random seed in ``[1, 10000]`` is drawn when
    ``random_seed`` is unset.
    """
    # open_dict lets us add a key that is not already present in the config.
    with open_dict(hydra_cfg):
        hydra_cfg.job_logging_cfg = HydraConfig.get().job_logging

    seed_missing = hydra_cfg.random_seed is None
    if seed_missing:
        # Store the drawn seed on the config so the value in use is recorded.
        hydra_cfg.random_seed = random.randint(1, 10000)
    set_random_seed(hydra_cfg.random_seed)

    Experiment(hydra_cfg)
    

# Run the Hydra-decorated entry point only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
